# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union

from mmengine.runner import IterBasedTrainLoop
from torch.utils.data import DataLoader


class TrainLoop(IterBasedTrainLoop):
    """Iteration-based train loop whose length may also be given in epochs.

    Exactly one of ``max_iters`` or ``max_epochs`` must be supplied. When
    ``max_epochs`` is used, it is converted to an iteration count by
    multiplying it with ``len(dataloader)`` before being handed to the parent
    loop.
    """

    def __init__(
        self,
        runner,
        dataloader: Union[DataLoader, Dict],
        max_iters: Optional[int] = None,
        max_epochs: Optional[Union[int, float]] = None,
        **kwargs,
    ) -> None:
        """Resolve the iteration budget and initialise the parent loop.

        Args:
            runner: The runner that owns this loop. Its ``build_dataloader``
                method, ``seed`` attribute and ``_randomness_cfg`` mapping are
                used when ``dataloader`` is given as a config dict and
                ``max_epochs`` is set.
            dataloader: A built dataloader, or a dict config for one.
            max_iters: Total number of training iterations. Must be
                integer-valued.
            max_epochs: Training length in epochs. The iteration count is
                ``max_epochs * len(dataloader)``; the product is forwarded
                unchanged, so a fractional result is not rounded here.
                NOTE(review): the parent loop is presumably responsible for
                rejecting a non-integral value - confirm against mmengine.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            RuntimeError: If neither or both of ``max_iters`` and
                ``max_epochs`` are given.
            AssertionError: If ``max_iters`` is not integer-valued.
        """
        if max_iters is None and max_epochs is None:
            raise RuntimeError(
                "Please specify the `max_iters` or `max_epochs` in `train_cfg`."
            )
        if max_iters is not None and max_epochs is not None:
            raise RuntimeError(
                "Only one of `max_iters` or `max_epochs` can exist in `train_cfg`."
            )

        if max_iters is not None:
            iters = int(max_iters)
            # Raised explicitly (instead of via `assert`) so that the check is
            # not stripped under `python -O`; the exception type is kept as
            # AssertionError for backward compatibility.
            if iters != max_iters:
                raise AssertionError(
                    f"`max_iters` should be a integer number, but get {max_iters}"
                )
        else:
            # Only `max_epochs` is set here. The dataloader length is needed
            # to convert epochs to iterations, so a dict config has to be
            # built into a real dataloader first.
            if isinstance(dataloader, dict):
                diff_rank_seed = runner._randomness_cfg.get("diff_rank_seed", False)
                dataloader = runner.build_dataloader(
                    dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed
                )
            iters = max_epochs * len(dataloader)

        super().__init__(
            runner=runner, dataloader=dataloader, max_iters=iters, **kwargs
        )
